import os
import os.path as osp
import random
import uuid
from typing import Dict, Optional, Sequence

import numpy as np
import torch
import yaml


def get_cost_push1(next_obs, start=12):
    """Cost for Point Push1: 1 if any hazard-radar bin reads >= 0.9, else 0.

    ``next_obs`` is read as consecutive 16-bin radars starting at ``start``;
    the hazard radar is the second of them, ``next_obs[start + 16:start + 32]``.
    """
    hazard_bins = next_obs[start + 16:start + 32]
    return 1 if hazard_bins.max() >= 0.9 else 0
def get_cost_push1_car(next_obs, start=24):
    """Cost for Car Push1: 1 if any hazard-radar bin reads >= 0.9, else 0.

    Same rule as the Point variant, but the radars start at index 24 by default.
    The hazard radar is ``next_obs[start + 16:start + 32]``.
    """
    lo, hi = start + 16, start + 32
    is_unsafe = next_obs[lo:hi].max() >= 0.9
    return int(is_unsafe)
def get_cost_push2(next_obs, start=12):
    """Cost for Point Push2: 1 if a hazard bin reads >= 0.9 or a pillar bin >= 0.88.

    Radars are consecutive 16-bin slices starting at ``start``. The hazard radar
    is the second slice and the pillar radar the third.
    """
    offset = start + 16
    hazard_bins = next_obs[offset:offset + 16]
    pillar_bins = next_obs[offset + 16:offset + 32]
    unsafe = hazard_bins.max() >= 0.9 or pillar_bins.max() >= 0.88
    return 1 if unsafe else 0
def get_cost_push2_car(next_obs, start=24):
    """Cost for Car Push2: 1 if a hazard bin reads >= 0.9 or a pillar bin >= 0.85.

    Radars are consecutive 16-bin slices starting at ``start`` (24 by default).
    Hazards occupy ``[start+16, start+32)`` and pillars ``[start+32, start+48)``.
    """
    hazards = next_obs[start + 16:start + 32]
    pillars = next_obs[start + 32:start + 48]
    if hazards.max() >= 0.9:
        return 1
    if pillars.max() >= 0.85:
        return 1
    return 0
def get_cost_goal1(next_obs, start=12):
    """Cost for Point Goal1: 1 if any hazard-radar bin reads >= 0.9, else 0.

    The hazard radar is the 16-bin slice ``next_obs[start + 16:start + 32]``.
    """
    strongest_hazard = next_obs[start + 16:start + 32].max()
    return int(strongest_hazard >= 0.9)
def get_cost_goal1_car(next_obs, start=24):
    """Cost for Car Goal1: 1 if any hazard-radar bin reads >= 0.9, else 0.

    The hazard radar is the 16-bin slice ``next_obs[start + 16:start + 32]``,
    with ``start`` defaulting to 24 for the Car observation layout.
    """
    first, last = start + 16, start + 32
    if next_obs[first:last].max() >= 0.9:
        return 1
    return 0
def get_cost_goal2(next_obs, start=12):
    """Cost for Point Goal2: 1 if a hazard bin reads >= 0.9 or a vase bin >= 0.88.

    Hazards are ``next_obs[start+16:start+32]``; vases ``next_obs[start+32:start+48]``.
    """
    hazard_peak = next_obs[start + 16:start + 32].max()
    if hazard_peak >= 0.9:
        return 1
    vase_peak = next_obs[start + 32:start + 48].max()
    return 1 if vase_peak >= 0.88 else 0
def get_cost_goal2_car(next_obs, start=24):
    """Cost for Car Goal2: 1 if a hazard bin reads >= 0.9 or a vase bin >= 0.85.

    Hazards are ``next_obs[start+16:start+32]``; vases ``next_obs[start+32:start+48]``.
    """
    hazard_radar = next_obs[start + 16:start + 32]
    vase_radar = next_obs[start + 32:start + 48]
    violated = hazard_radar.max() >= 0.9 or vase_radar.max() >= 0.85
    return int(violated)
    

def get_cost_circle1(next_obs, start=12):
    """Cost for Circle using the 16-bin circle radar at ``next_obs[start:start+16]``.

    The strongest reading m is converted to a distance estimate (1 - m) * 6, and
    the step is flagged (returns 1) when that estimate's magnitude is >= 1.0.
    """
    strongest = next_obs[start:start + 16].max()
    distance = (1 - strongest) * 6
    return 1 if np.abs(distance) >= 1.0 else 0
def get_cost_circle1_car(next_obs, start=24):
    """Car variant of the circle-radar cost; radar at ``next_obs[start:start+16]``.

    Returns 1 when |(1 - max_reading) * 6| >= 1.0, otherwise 0.
    """
    radar = next_obs[start:start + 16]
    estimate = (1 - radar.max()) * 6
    if np.abs(estimate) >= 1.0:
        return 1
    return 0
def get_cost_circle1_llm(observation):
    """
    Return 1 (unsafe) or 0 (safe) from the circle radar in the last 16 entries.

    observation: any sliceable sequence (list, numpy array, ...) whose final 16
    entries are radar bins. A reading r is taken to encode d = 6 - 6*r
    (0 <= d <= 6), so the largest reading yields the smallest distance estimate.
    The step is unsafe when that estimate exceeds 1.125 m.
    """
    strongest_reading = max(observation[-16:])
    estimated_distance = 6.0 - 6.0 * strongest_reading
    if estimated_distance > 1.125:
        return 1
    return 0
def get_cost_circle2_llm(obs) -> int:
    """
    Return 1 (unsafe) or 0 (safe) from the 16-bin radar in the last 16 entries of ``obs``.

    Each bin reading v is taken to encode v = (6 - d) / 6 for the single radar
    target (described as the origin), so the strongest reading gives the range
    estimate d = 6 * (1 - max(v)). The step is flagged when d > 1.125 m.

    Direction of the rule: if either axis offset exceeds 1.125 m, the Euclidean
    distance d does too, so no such state is missed. The converse does not hold
    (e.g. x = y = 0.9 gives d ~= 1.27), so states with both offsets inside 1.125 m
    can also be flagged; the rule is conservative. Lower the threshold for even
    more conservatism.
    """
    tail = obs[-16:]
    estimated_range = 6 * (1.0 - max(tail))
    safety_threshold = 1.125  # metres
    if estimated_range > safety_threshold:
        return 1
    return 0

def get_cost_button(next_obs, start=12):
    """Cost for Point Button tasks from four consecutive 16-bin radars at ``start``.

    Layout (offsets from ``start``): button [0, 16), goal [16, 32) (not used),
    gremlin [32, 48), hazard [48, 64). Returns 1 when any of these holds:
      * the button radar is non-empty and peaks at >= 0.93;
      * the hazard radar peaks strictly above 0.9;
      * the gremlin radar peaks at >= 0.93.
    Checks run in that order and stop at the first hit.
    """
    button_radar = next_obs[start:start + 16]
    gremlin_radar = next_obs[start + 32:start + 48]
    hazard_radar = next_obs[start + 48:start + 64]

    if button_radar.shape[0] != 0 and button_radar.max() >= 0.93:
        return 1
    if hazard_radar.max() > 0.9:
        return 1
    if gremlin_radar.max() >= 0.93:
        return 1
    return 0
    
def get_cost_button_car(next_obs, start=24):
    """Car variant of the Button cost; radars start at ``start`` (24 by default).

    Layout (offsets from ``start``): button [0, 16), goal [16, 32) (not used),
    gremlin [32, 48), hazard [48, 64). Unsafe (1) if the button radar is
    non-empty with a peak >= 0.9, or the hazard peak is > 0.9, or the gremlin
    peak is >= 0.9; otherwise 0.
    """
    radar = lambda k: next_obs[start + 16 * k:start + 16 * (k + 1)]
    button, gremlin, hazard = radar(0), radar(2), radar(3)
    hit = (button.shape[0] != 0 and button.max() >= 0.9) \
        or hazard.max() > 0.9 \
        or gremlin.max() >= 0.9
    return 1 if hit else 0

def parse_radars(observation):
    """
    Split the trailing 64 entries of an observation into four 16-bin radars.

    Returns the tuple (button, goal, gremlin, hazard), taken from
    obs[-64:-48], obs[-48:-32], obs[-32:-16] and obs[-16:] respectively.
    The goal radar is returned but not used by get_cost_button_llm.

    Raises ValueError if the observation has fewer than 64 elements.
    """
    arr = np.asarray(observation)
    if arr.size < 64:
        raise ValueError("Observation must have at least 64 dimensions")
    return arr[-64:-48], arr[-48:-32], arr[-32:-16], arr[-16:]

def radar_distances(radar_values):
    """
    Invert radar readings r = (3 - d) / 3 into distances d = 3 * (1 - r).

    Only strictly positive readings (actual detections) are converted. The
    result is an array containing one distance per detected bin, in bin order.
    """
    hits = radar_values[radar_values > 0]
    return 3.0 * (1.0 - hits)

def get_cost_button_llm(observation,
             button_margin=0.35,
             gremlin_margin=0.35,
             hazard_margin=0.30):
    """
    Return 1 if any tracked object is within its margin, else 0.

    The button, gremlin and hazard radars are taken from the last 64 entries
    of ``observation`` via parse_radars; the goal radar is ignored. Each
    radar's detections are converted to distances via radar_distances and
    compared against the matching margin (metres), in the order button,
    gremlin, hazard. The first radar with a detection at or inside its margin
    makes the step unsafe.

    NOTE(review): the margins are described as slightly above a 0.3 m
    physical radius, but hazard_margin equals 0.30 — confirm the intended value.
    """
    button_r, _, gremlin_r, hazard_r = parse_radars(observation)
    checks = (
        (button_r, button_margin),
        (gremlin_r, gremlin_margin),
        (hazard_r, hazard_margin),
    )
    for readings, margin in checks:
        distances = radar_distances(readings)
        if distances.size > 0 and np.any(distances <= margin):
            return 1
    return 0

# Offline Safety-Gymnasium task id -> per-step cost function.
# Every value maps one observation to a 0/1 cost.
# The trailing "# N" comments number entries in insertion order; callers may
# rely on that order (TODO confirm), so keep it stable when editing.
env2cost_dict = {
    # safety_gymnasium: car
    "OfflineCarButton1Gymnasium-v0":   get_cost_button_llm,     # 0
    "OfflineCarButton2Gymnasium-v0":   get_cost_button_llm,     # 1
    "OfflineCarCircle1Gymnasium-v0":   get_cost_circle1_llm,    # 2
    "OfflineCarCircle2Gymnasium-v0":   get_cost_circle2_llm,    # 3
    "OfflineCarGoal1Gymnasium-v0":     get_cost_goal1_car,      # 4
    "OfflineCarGoal2Gymnasium-v0":     get_cost_goal2_car,      # 5
    "OfflineCarPush1Gymnasium-v0":     get_cost_push1_car,      # 6
    "OfflineCarPush2Gymnasium-v0":     get_cost_push2_car,      # 7
    # safety_gymnasium: point
    "OfflinePointButton1Gymnasium-v0": get_cost_button_llm,     # 8
    "OfflinePointButton2Gymnasium-v0": get_cost_button_llm,     # 9
    "OfflinePointCircle1Gymnasium-v0": get_cost_circle1_llm,    # 10
    "OfflinePointCircle2Gymnasium-v0": get_cost_circle2_llm,    # 11
    "OfflinePointGoal1Gymnasium-v0":   get_cost_goal1,          # 12
    "OfflinePointGoal2Gymnasium-v0":   get_cost_goal2,          # 13
    "OfflinePointPush1Gymnasium-v0":   get_cost_push1,          # 14
    "OfflinePointPush2Gymnasium-v0":   get_cost_push2,          # 15
}



def seed_all(seed=1029, others: Optional[list] = None):
    """
    Seed Python, NumPy and PyTorch RNGs, and optionally external objects.

    Also exports ``PYTHONHASHSEED`` and puts cuDNN into deterministic,
    non-benchmark mode.

    :param seed: the seed value.
    :param others: either a single object with a ``seed(seed)`` method, or an
        iterable of objects. Items without a ``seed`` attribute are skipped.
        Non-iterable, non-seedable values are ignored.
    :return: True if ``others`` itself was seeded directly, otherwise None.
    """
    import warnings

    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    # Optionally: torch.use_deterministic_algorithms(True) for stricter determinism.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    if others is None:
        return None
    if hasattr(others, "seed"):
        others.seed(seed)
        return True
    try:
        items = iter(others)
    except TypeError:
        # Neither seedable nor iterable: nothing else to seed.
        return None
    for item in items:
        if not hasattr(item, "seed"):
            continue
        # Best effort per item: one failure must not leave the rest unseeded,
        # but it is reported rather than silently swallowed.
        try:
            item.seed(seed)
        except Exception as err:
            warnings.warn(f"seed_all: failed to seed {item!r}: {err}")
    return None


def get_cfg_value(config, key):
    """
    Look up ``key`` in a (possibly nested) config dict and return it as a string.

    The top level is checked first. Otherwise nested dict values are searched
    depth-first, in insertion order, and the first match wins. List values are
    rendered by concatenating ``str`` of each element; anything else goes
    through ``str``.

    :return: the string value, or the string "None" if ``key`` is not found.
    """
    found = _find_cfg_value(config, key)
    return "None" if found is None else found


def _find_cfg_value(config, key):
    """Recursive worker for get_cfg_value; returns None (not "None") on a miss."""
    if key in config:
        value = config[key]
        if isinstance(value, list):
            return "".join(str(i) for i in value)
        return str(value)
    for sub in config.values():
        if isinstance(sub, dict):
            # A real None sentinel lets the search continue into later siblings
            # instead of stopping at the first nested dict.
            res = _find_cfg_value(sub, key)
            if res is not None:
                return res
    return None


def load_config_and_model(path: str, best: bool = False, device=None):
    '''
    Load an experiment's ``config.yaml`` and checkpoint from a run directory.

    :param path: run directory containing ``config.yaml`` and ``checkpoint/``.
    :param best: load ``checkpoint/model_best.pt`` instead of ``checkpoint/model.pt``.
    :param device: forwarded to ``torch.load`` as ``map_location`` when not None.

    :return: a tuple (config dict, object returned by ``torch.load``).
    :raises ValueError: if ``path`` does not exist.

    NOTE(review): recent PyTorch releases default ``torch.load`` to
    ``weights_only=True``, which rejects pickled module objects. Confirm what
    the checkpoint contains.
    '''
    if not osp.exists(path):
        raise ValueError(f"{path} doesn't exist!")

    config_file = osp.join(path, "config.yaml")
    print(f"load config from {config_file}")
    with open(config_file) as f:
        config = yaml.load(f.read(), Loader=yaml.FullLoader)

    checkpoint_name = "model_best.pt" if best else "model.pt"
    model_path = osp.join(path, "checkpoint/" + checkpoint_name)
    print(f"load model from {model_path}")
    load_kwargs = {} if device is None else {"map_location": device}
    model = torch.load(model_path, **load_kwargs)
    return config, model


def to_string(values):
    '''
    Flatten a (possibly nested) sequence or dict into an underscore-joined string.

    Sequences (other than str) are joined element-wise. Dicts contribute their
    values ordered by sorted key. Every other value is rendered with ``str``.

    :param values: the value to render.
    :return: the rendered string, e.g. ``[1, {"b": 2, "a": 3}]`` -> ``"1_3_2"``.
    '''
    if isinstance(values, Sequence) and not isinstance(values, str):
        return "_".join(to_string(item) for item in values)
    if isinstance(values, Dict):
        return "_".join(to_string(values[key]) for key in sorted(values))
    return str(values)


# Config keys that auto_name ignores when deriving an experiment name
# (presumably bookkeeping / runtime options rather than hyperparameters).
DEFAULT_SKIP_KEY = [
    "task",
    "reward_threshold",
    "logdir",
    "worker",
    "project",
    "group",
    "name",
    "prefix",
    "suffix",
    "save_interval",
    "render",
    "verbose",
    "save_ckpt",
    "training_num",
    "testing_num",
    "epoch",
    "device",
    "thread",
]

# Short labels auto_name substitutes for long config keys.
DEFAULT_KEY_ABBRE = {
    "cost_limit": "cost",
    "mstep_iter_num": "mnum",
    "estep_iter_num": "enum",
    "estep_kl": "ekl",
    "mstep_kl_mu": "kl_mu",
    "mstep_kl_std": "kl_std",
    "mstep_dual_lr": "mlr",
    "estep_dual_lr": "elr",
    "update_per_step": "update",
}


def auto_name(default_cfg: dict,
              current_cfg: dict,
              prefix: str = "",
              suffix: str = "",
              skip_keys: list = DEFAULT_SKIP_KEY,
              key_abbre: dict = DEFAULT_KEY_ABBRE) -> str:
    '''
    Generate an experiment name from the config entries that differ from the defaults.

    Keys of ``default_cfg`` are visited in sorted order. A key is skipped if it
    is in ``skip_keys``, is absent from ``current_cfg``, or has the same value in
    both configs. Every remaining key contributes ``<label><value>``, where
    ``label`` is ``key_abbre.get(key, key)`` and ``value`` is ``to_string`` of
    the current value. Parts are joined with "_".

    :param dict default_cfg: the default configuration values.
    :param dict current_cfg: the current configuration values.
    :param str prefix: (optional) text placed at the start of the name.
    :param str suffix: (optional) text appended (after "_") at the end of the name.
    :param list skip_keys: (optional) keys never included in the name.
    :param dict key_abbre: (optional) abbreviations for keys in the name.

    :return str: the generated name ("default" if nothing differs), followed by
        "-" and the first 4 characters of a random UUID4, so repeated calls
        generally differ.
    '''
    name = prefix
    for key in sorted(default_cfg.keys()):
        # Test skip_keys before comparing values, so skipped (or missing) keys
        # never need to be present in current_cfg.
        if key in skip_keys or key not in current_cfg:
            continue
        if default_cfg[key] == current_cfg[key]:
            continue
        sep = "_" if name else ""
        value = to_string(current_cfg[key])
        label = key_abbre.get(key, key)
        name += sep + label + value
    if suffix:
        name = name + "_" + suffix if name else suffix

    if not name:
        name = "default"
    return f"{name}-{str(uuid.uuid4())[:4]}"

def visualization(obs_seq, name="", path = "", model = None, step = 0, returns = None, test_list = None):
    """
    Plot forward-noising and reverse-denoising snapshots of a trajectory (Circle tasks only).

    The figure has two rows and one column per selected diffusion step
    (0, 1, T//4, T//2, 3T//4, T-1, with T = ``model.n_timesteps``):
      * top row: the trajectory noised by ``model.q_sample`` at that step;
      * bottom row: the chain from ``model.conditional_sample`` at that step,
        with one scatter per cost-return target.
    Columns 0 and 1 of ``obs_seq`` are plotted as x and y, within [-2, 2].

    :param obs_seq: tensor indexed as [horizon, obs_dim].
    :param name: suffix of the output file name.
    :param path: prefix (directory + separator) of the output file name.
    :param model: diffusion model exposing ``n_timesteps``, ``betas``,
        ``q_sample`` and ``conditional_sample``. When None, only the
        trajectory is plotted.
    :param step: training step embedded in the output file name.
    :param returns: reward-return condition for ``conditional_sample``;
        0.75 when None.
    :param test_list: cost-return targets to sample; [0.2, 0.5, 0.7, 0.9] when None.

    The figure is saved to ``path + "Step-<step>-<name>.png"`` and then closed.

    NOTE(review): ``plt``, ``apply_conditioning`` and ``to_torch`` are not
    imported in this module. Presumably they are matplotlib.pyplot and the
    project's diffusion helpers; confirm they are in scope before calling.
    """
    horizon, dim = obs_seq.shape[0], obs_seq.shape[1]
    x_seq = obs_seq[:, 0].cpu().numpy()
    y_seq = obs_seq[:, 1].cpu().numpy()
    init_x, init_y = x_seq[:1], y_seq[:1]
    out_file = path+"Step-"+str(step)+"-"+name+".png"

    def _plot_reference(axis):
        # Shared background for every panel: the clean trajectory and its start point.
        axis.set_xlim([-2, 2])
        axis.set_ylim([-2, 2])
        axis.scatter(x_seq, y_seq, label="original dataset")
        axis.scatter(init_x, init_y, color='r', label="initial point", marker="*", linewidths=1)

    if model is None:
        fig, axis = plt.subplots(figsize=(5, 5))
        _plot_reference(axis)
        axis.legend()
        fig.savefig(out_file, dpi=400)
        plt.close(fig)
        return

    if test_list is None:
        test_list = [0.2, 0.5, 0.7, 0.9]
    reward_return = 0.75 if returns is None else returns

    tnp = np.array([0, 1,
                    int(model.n_timesteps * 1 // 4),
                    int(model.n_timesteps * 2 // 4),
                    int(model.n_timesteps * 3 // 4),
                    model.n_timesteps-1,
                    ])
    fig, ax = plt.subplots(ncols=tnp.shape[0], nrows=2, figsize=(30,7))

    x_start = obs_seq.unsqueeze(0)  # [1, horizon, dim]

    # Forward process: noise the clean trajectory at every selected step in one
    # call (one t per column), then re-impose the initial-state condition.
    cond = {0: x_start[:, 0]}
    noise = torch.randn(size=(1, horizon, dim), device=obs_seq.device)
    t = torch.tensor(tnp, device=obs_seq.device)
    x_noisy = model.q_sample(x_start=x_start, t=t, noise=noise)
    x_noisy = apply_conditioning(x_noisy, cond, 0)
    x_noisy = x_noisy.cpu().numpy()
    x_seq_noisy, y_seq_noisy = x_noisy[:, :, 0], x_noisy[:, :, 1]

    # Reverse process: draw one chain per cost-return target up front, so every
    # column shows snapshots of the same chain and sampling runs once per target.
    conditions = {0: to_torch(x_start[:, 0], device=model.betas.device)}
    chains = []
    for test_ret in test_list:
        _, diffusion = model.conditional_sample(conditions,
                                                return_diffusion=True,
                                                returns=reward_return,
                                                cost_returns=test_ret)
        # Flip the leading (step) axis before indexing by t_step below; this
        # assumes the chain is returned noisiest-step first -- confirm.
        chain = np.flip(diffusion.squeeze(axis=0).cpu().numpy(), 0)
        chains.append((test_ret, chain))

    for col in range(x_seq_noisy.shape[0]):
        t_step = tnp[col]
        top, bottom = ax[0][col], ax[1][col]

        # forward
        _plot_reference(top)
        top.scatter(x_seq_noisy[col],
                    y_seq_noisy[col],
                    label="noise step {}".format(t_step),
                    alpha=0.35)
        top.scatter(x_seq_noisy[col, 0],
                    y_seq_noisy[col, 0],
                    label="noise step {}, init".format(t_step),
                    alpha=0.35, marker="X")
        top.set_title(f"forward t={t_step}")
        top.legend()

        # inverse
        _plot_reference(bottom)
        for test_ret, chain in chains:
            data = chain[t_step]
            bottom.scatter(data[:, 0],
                           data[:, 1],
                           label="return {}, denoise step {}".format(test_ret, t_step),
                           alpha=0.35)
        bottom.set_title(f"inverse t={t_step}")
        bottom.legend()

    fig.savefig(out_file, dpi=400)
    # Close explicitly so repeated calls (e.g. during training) don't accumulate figures.
    plt.close(fig)

def visualization_seqlen(obs_seq,
                         name="",
                         path = "",
                         model = None,
                         step = 0,
                         returns = None,
                         seq_len_list = None):
    """
    Forward/reverse diffusion snapshot plot (Circle tasks only) for a single
    cost-return target of 0.7.

    Layout: two rows, one column per selected diffusion step (0, 1, T//4,
    T//2, 3T//4, T-1, with T = ``model.n_timesteps``). The top row shows the
    ``q_sample``-noised trajectory and the bottom row the
    ``conditional_sample`` chain at that step. Columns 0 and 1 of ``obs_seq``
    are plotted as x and y, within [-2, 2].

    :param obs_seq: tensor indexed as [horizon, obs_dim].
    :param name: suffix of the output file name.
    :param path: prefix (directory + separator) of the output file name.
    :param model: diffusion model exposing ``n_timesteps``, ``betas``,
        ``q_sample`` and ``conditional_sample``. When None, only the
        trajectory is plotted.
    :param step: training step embedded in the output file name.
    :param returns: reward-return condition for ``conditional_sample``. When
        None, the cost-return target is used for both conditions.
    :param seq_len_list: currently unused. NOTE(review): the name suggests it
        was meant to vary the sampled sequence length; confirm the intent.

    The figure is saved to ``path + "Step-<step>-<name>.png"`` and then closed.

    NOTE(review): ``plt``, ``apply_conditioning`` and ``to_torch`` are not
    imported in this module; confirm they are in scope before calling.
    """
    horizon, dim = obs_seq.shape[0], obs_seq.shape[1]
    x_seq = obs_seq[:, 0].cpu().numpy()
    y_seq = obs_seq[:, 1].cpu().numpy()
    init_x, init_y = x_seq[:1], y_seq[:1]
    out_file = path+"Step-"+str(step)+"-"+name+".png"

    def _draw_background(axis):
        # Clean trajectory plus its start point, shared by every panel.
        axis.set_xlim([-2, 2])
        axis.set_ylim([-2, 2])
        axis.scatter(x_seq, y_seq, label="original dataset")
        axis.scatter(init_x, init_y, color='r', label="initial point", marker="*", linewidths=1)

    if model is None:
        fig, axis = plt.subplots(figsize=(5, 5))
        _draw_background(axis)
        axis.legend()
        fig.savefig(out_file, dpi=400)
        plt.close(fig)
        return

    test_list = [0.7]

    tnp = np.array([0, 1,
                    int(model.n_timesteps * 1 // 4),
                    int(model.n_timesteps * 2 // 4),
                    int(model.n_timesteps * 3 // 4),
                    model.n_timesteps-1,
                    ])
    fig, ax = plt.subplots(ncols=tnp.shape[0], nrows=2, figsize=(30,7))

    x_start = obs_seq.unsqueeze(0)  # [1, horizon, dim]

    # Forward process at all selected steps in one call, then re-apply the
    # initial-state condition.
    cond = {0: x_start[:, 0]}
    noise = torch.randn(size=(1, horizon, dim), device=obs_seq.device)
    t = torch.tensor(tnp, device=obs_seq.device)
    x_noisy = model.q_sample(x_start=x_start, t=t, noise=noise)
    x_noisy = apply_conditioning(x_noisy, cond, 0)
    x_noisy = x_noisy.cpu().numpy()
    x_seq_noisy, y_seq_noisy = x_noisy[:, :, 0], x_noisy[:, :, 1]

    # Reverse process: sample each chain once, so all columns show the same chain.
    conditions = {0: to_torch(x_start[:, 0], device=model.betas.device)}
    chains = []
    for test_ret in test_list:
        reward_return = test_ret if returns is None else returns
        _, diffusion = model.conditional_sample(conditions,
                                                return_diffusion=True,
                                                returns=reward_return,
                                                cost_returns=test_ret)
        # Flip the leading (step) axis before indexing by t_step; assumes the
        # chain is returned noisiest-step first -- confirm.
        chains.append((test_ret, np.flip(diffusion.squeeze(axis=0).cpu().numpy(), 0)))

    for col in range(x_seq_noisy.shape[0]):
        t_step = tnp[col]
        upper, lower = ax[0][col], ax[1][col]

        # forward
        _draw_background(upper)
        upper.scatter(x_seq_noisy[col],
                      y_seq_noisy[col],
                      label="noise step {}".format(t_step),
                      alpha=0.35)
        upper.scatter(x_seq_noisy[col, 0],
                      y_seq_noisy[col, 0],
                      label="noise step {}, init".format(t_step),
                      alpha=0.35, marker="X")
        upper.set_title(f"forward t={t_step}")
        upper.legend()

        # inverse
        _draw_background(lower)
        for test_ret, chain in chains:
            data = chain[t_step]
            lower.scatter(data[:, 0],
                          data[:, 1],
                          label="return {}, denoise step {}".format(test_ret, t_step),
                          alpha=0.35)
        lower.set_title(f"inverse t={t_step}")
        lower.legend()

    fig.savefig(out_file, dpi=400)
    # Release the figure; otherwise every call leaves an open figure behind.
    plt.close(fig)
